Flax schedulers: replace control flow with functions#577
Closed
Flax schedulers: replace control flow with functions#577
Conversation
Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects. I temporarily removed `step_prk`.
|
The documentation is not available anymore as the PR was closed or merged. |
Member
Author
|
We can test with just import jax
import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler
scheduler = FlaxPNDMScheduler.from_config(PATH_TO_SCHEDULER_DIR)
latents_shape = (1, 64, 64, 3)
scheduler_state = scheduler.set_timesteps(
scheduler.state,
shape = latents_shape, # Needs to be known in advance to reserve space
num_inference_steps = 50,
)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
latents = jax.random.normal(key1, shape=latents_shape, dtype=jnp.float32)
noise = jax.random.normal(key2, shape=latents_shape, dtype=jnp.float32)
p_step = jax.jit(scheduler.step, static_argnums=4)
latents, scheduler_state = p_step(scheduler_state, noise, 37, latents, return_dict=False)This example should work with both |
Contributor
|
Cool! Will focus on DDIM for now to get the pipeline working with it |
Merged
Member
Author
|
Replaced by #583 for PNDM. We'll open separate PRs for others. |
PhaneeshB
pushed a commit
to nod-ai/diffusers
that referenced
this pull request
Mar 1, 2023
* Minor fixes to benchmark runner * Add Mnasnet to tank.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects.
I started with
FlaxPNDMScheduler, and temporarily removedstep_prk.